{
 "cells": [
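  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune and deploy the ChatGLM2-6b large language model\n",
    "\n",
    "A large language model (LLM) is a very large language model pretrained on massive amounts of data. By training on large text corpora, these models learn the complex structure and nuances of language and perform well on tasks such as text generation and language understanding.\n",
    "\n",
    "PAI provides many popular open-source large language models that can be deployed out of the box or fine-tuned on private data. This example uses ChatGLM2-6b to show how to fine-tune and deploy a large language model provided by PAI.\n"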
   ]
  },
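  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Billing\n",
    "\n",
    "This example uses the following cloud products and incurs the corresponding charges:\n",
    "\n",
    "- PAI-DLC: runs the training job. For billing details, see [Billing of PAI-DLC](https://help.aliyun.com/zh/pai/product-overview/billing-of-dlc)\n",
    "- PAI-EAS: hosts the inference service. For billing details, see [Billing of PAI-EAS](https://help.aliyun.com/zh/pai/product-overview/billing-of-eas)\n",
    "- OSS: stores the model produced by the training job, the training code, TensorBoard logs, and so on. For billing details, see [OSS billing overview](https://help.aliyun.com/zh/oss/product-overview/billing-overview)\n",
    "\n",
    "\n",
    "> If you join the free trial for cloud products and use the **designated instance types** to submit training jobs or deploy inference services, you can try PAI for free. For details, see [PAI free trial](https://help.aliyun.com/zh/pai/product-overview/free-quota-for-new-users).\n",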
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Prerequisites\n",
    "\n",
    "Install the PAI Python SDK with the following command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m pip install --upgrade pai"
   ]
  },
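  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "The SDK needs an AccessKey to access Alibaba Cloud services, along with the workspace and OSS bucket to use. After installing the PAI Python SDK, run the following command in a **terminal** and follow the prompts to configure the credentials, workspace, and other settings.\n",
    "\n",
    "\n",
    "```shell\n",
    "\n",
    "# Run the following command in a terminal.\n",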
    "\n",
    "python -m pai.toolkit.config\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.session import get_default_session, setup_default_session\n",
    "\n",
    "# Configure the SDK session as follows.\n",
    "sess = get_default_session()\n",
    "if not sess:\n",
    "    print(\"config session\")\n",
    "    sess = setup_default_session(\n",
    "        region_id=\"<REGION_ID>\",  # e.g. cn-beijing\n",
    "        workspace_id=\"<WORKSPACE_ID>\",  # e.g. 12345\n",
    "        oss_bucket_name=\"<OSS_BUCKET_NAME>\",\n",
    "    )\n",
    "# Persist the current configuration to ~/.pai/config.json; by default the SDK reads this file to initialize the default session.\n",
    "sess.save_config()\n",
    "\n",
    "print(sess.workspace_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "The following code verifies the current configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pai\n",
    "from pai.session import get_default_session\n",
    "\n",
    "print(pai.__version__)\n",
    "sess = get_default_session()\n",
    "print(sess.workspace_id)\n",
    "print(sess.workspace_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LLMs supported by PAI\n",
    "\n",
    "Calling `RegisteredModel.list` with `task=\"large-language-model\"` and `model_provider=\"pai\"` returns all large language models supported by PAI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table><tr><td style=\"text-align:left;\">ModelName</td><td style=\"text-align:left;\">ModelVersion</td><td style=\"text-align:left;\">ModelProvider</td><td style=\"text-align:left;\">Task</td><td style=\"text-align:left;\">Labels</td><td style=\"text-align:left;\">SupportTraining</td><td style=\"text-align:left;\">SupportDeploy</td></tr><tr><td style=\"text-align:left;\">baichuan2-7b-chat</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">disabled=True, framework=PyTorch, language=Multilingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">baichuan2-7b-base</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">disabled=True, framework=PyTorch, language=Multilingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">qwen-7b-chat</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">disabled=True, framework=PyTorch, language=Multilingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">qwen-7b</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">disabled=True, framework=PyTorch, language=Multilingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td 
style=\"text-align:left;\">llama-2-7b-chat</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">llama-2-7b</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">falcon-7b-instruct</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">falcon-7b</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">dolly-v2-7b</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td 
style=\"text-align:left;\">dolly-v2-3b</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=3000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">chatglm2-6b</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Chinese, model_size=6000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">aquilachat-7b</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">disabled=True, framework=PyTorch, language=Multilingual, model_size=7000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">aquila-7b</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">disabled=True, framework=PyTorch, language=Multilingual, model_size=7000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">bloom-7b1</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=3000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td 
style=\"text-align:left;\">bloom-3b</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=3000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">bloom-1b7</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=1700000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">qwen-7b-chat-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multilingual, model_size=7000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">falcon-7b-instruct-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">falcon-7b-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td 
style=\"text-align:left;\">dolly-v2-7b-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=English, model_size=7000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">dolly-v2-3b-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=English, model_size=2800000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">aquilachat-7b-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Chinese, model_size=7000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">aquila-7b-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Chinese, model_size=7000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">chatglm2-6b-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Chinese, model_size=6000000000, source=ModelScope</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td 
style=\"text-align:left;\">bloom-7b1-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7100000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">bloom-3b-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=3000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">bloom-1b7-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=1700000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">bloom-1b1-lora</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=1100000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">llama-2-7b-chat-hf</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td 
style=\"text-align:left;\">llama-2-7b-hf</td><td style=\"text-align:left;\">0.3.0</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=7000000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr><tr><td style=\"text-align:left;\">bloom-1b1</td><td style=\"text-align:left;\">0.2.1</td><td style=\"text-align:left;\">pai</td><td style=\"text-align:left;\">large-language-model</td><td style=\"text-align:left;\">framework=PyTorch, language=Multi-lingual, model_size=1100000000, source=HuggingFace</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">True</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pai.model import RegisteredModel\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "\n",
    "def table_display(data):\n",
    "    content = \"\".join(\n",
    "        \"<tr>{}</tr>\".format(\n",
    "            \"\".join(\n",
    "                '<td style=\"text-align:left;\">{}</td>'.format(column) for column in row\n",
    "            )\n",
    "        )\n",
    "        for row in data\n",
    "    )\n",
    "    display(HTML(f\"<table>{content}</table>\"))\n",
    "\n",
    "\n",
    "data = [\n",
    "    [\n",
    "        \"ModelName\",\n",
    "        \"ModelVersion\",\n",
    "        \"ModelProvider\",\n",
    "        \"Task\",\n",
    "        \"Labels\",\n",
    "        \"SupportTraining\",\n",
    "        \"SupportDeploy\",\n",
    "    ]\n",
    "]\n",
    "for m in RegisteredModel.list(model_provider=\"pai\", task=\"large-language-model\"):\n",
    "    data.append(\n",
    "        [\n",
    "            m.model_name,\n",
    "            m.model_version,\n",
    "            m.model_provider,\n",
    "            m.task,\n",
    "            \", \".join([\"{}={}\".format(k, v) for k, v in m.version_labels.items()]),\n",
    "            bool(m.training_spec),\n",
    "            bool(m.inference_spec),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "\n",
    "table_display(data)"
   ]
  },
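  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploy the ChatGLM2-6b model\n",
    "\n",
    "\n",
    "The large language models provided by PAI can be deployed directly to PAI as an online service. Use the `RegisteredModel.deploy` method to deploy a model to PAI and create a dedicated LLM online service.\n",
    "\n",
    "The following uses ChatGLM2-6b to show how to deploy a large language model as an online service."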
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.model import RegisteredModel\n",
    "\n",
    "# Get the ChatGLM2-6b model provided by PAI\n",
    "m = RegisteredModel(model_name=\"chatglm2-6b\", model_provider=\"pai\")\n",
    "\n",
    "\n",
    "# Inspect the model's deployment configuration\n",
    "print(m.inference_spec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.common.utils import random_str\n",
    "\n",
    "# Deploy the model service\n",
    "predictor = m.deploy(\n",
    "    \"chatglm2_6b_{}\".format(random_str(8)),  # inference service name\n",
    "    #  instance_type=\"ecs.gn6v-c8g1.2xlarge\",   # instance type to use\n",
    "    #  instance_count=1,                        # number of instances\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deploying the model service returns a `Predictor` object, which can be used to call the online service."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'response': \"To install PyTorch, you can use pip, which is the package installer for Python. Here are the steps to install PyTorch using pip:\\n\\n1. Open a terminal or command prompt.\\n2. Enter the following command to install PyTorch:\\n```\\npip install torch torchvision\\n```\\n1. Wait for the installation to complete.\\n\\nThat's it! You should now have PyTorch installed.\",\n",
       " 'history': [['How to install PyTorch?',\n",
       "   \"To install PyTorch, you can use pip, which is the package installer for Python. Here are the steps to install PyTorch using pip:\\n\\n1. Open a terminal or command prompt.\\n2. Enter the following command to install PyTorch:\\n```\\npip install torch torchvision\\n```\\n1. Wait for the installation to complete.\\n\\nThat's it! You should now have PyTorch installed.\"]],\n",
       " 'usage': {'usage': None, 'finish_reason': 'stop'}}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res = predictor.predict(\n",
    "    data={\n",
    "        \"prompt\": \"How to install PyTorch?\",\n",
    "        \"system_prompt\": \"Act like you are programmer with 5+ years of experience.\",\n",
    "        \"temperature\": 0.8,\n",
    "    }\n",
    ")\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After testing, delete the deployed model service to avoid further charges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.delete_service()"
   ]
  },
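  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tune the ChatGLM2-6b model\n",
    "\n",
    "\n",
    "The large language models provided by PAI can be fine-tuned on private data. Use the `RegisteredModel.get_estimator` method to get the model's fine-tuning algorithm, then call `Estimator.fit` to run a fine-tuning job on your data."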
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.model import RegisteredModel\n",
    "\n",
    "# Get the ChatGLM2-6b model provided by PAI\n",
    "m = RegisteredModel(\"chatglm2-6b\", model_provider=\"pai\")\n",
    "\n",
    "# Get the model's fine-tuning algorithm\n",
    "est = m.get_estimator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect the hyperparameters supported by the model's fine-tuning algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hp_defs = [[\"HyperParameter Name\", \"Type\", \"DefaultValue\", \"Required\", \"Description\"]]\n",
    "for ch in est.hyperparameter_definitions:\n",
    "    hp_defs.append(\n",
    "        [\n",
    "            ch[\"Name\"],\n",
    "            ch[\"Type\"],\n",
    "            est.hyperparameters.get(ch[\"Name\"], ch[\"DefaultValue\"]),\n",
    "            ch[\"Required\"],\n",
    "            ch[\"Description\"],\n",
    "        ]\n",
    "    )\n",
    "\n",
    "table_display(hp_defs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Use `est.hyperparameters` and `est.set_hyperparameters` to view and set the hyperparameters of the fine-tuning algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'model': 'chatglm2', 'learning_rate': '1e-5', 'num_train_epochs': '5', 'per_device_train_batch_size': '16', 'max_seq_len': '512', 'lora_dim': '64', 'gradient_accumulation_steps': 1}\n",
      "{'model': 'chatglm2', 'learning_rate': '1e-5', 'num_train_epochs': '5', 'per_device_train_batch_size': 8, 'max_seq_len': '512', 'lora_dim': '64', 'gradient_accumulation_steps': 1}\n"
     ]
    }
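    ],
    "source": [
    "# Show the current hyperparameters of the algorithm\n",
    "print(est.hyperparameters)\n",
    "\n",
    "# Set the per-device training batch size\n",
    "est.set_hyperparameters(per_device_train_batch_size=8)\n",
    "\n",
    "# Show the hyperparameters again after the change\n",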
    "print(est.hyperparameters)"
   ]
  },
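  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PAI provides a sample dataset for fine-tuning. Call `m.get_estimator_inputs()` to get the default inputs supported by the PAI fine-tuning algorithm. You can use these inputs to test the algorithm, inspect the data, and then build your own training dataset."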
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'model': 'oss://pai-quickstart-cn-hangzhou.oss-cn-hangzhou-internal.aliyuncs.com/huggingface/models/chatglm2-6b/main/',\n",
      " 'train': 'oss://pai-quickstart-cn-hangzhou.oss-cn-hangzhou-internal.aliyuncs.com/huggingface/datasets/llm_instruct/ch_poetry_train.json',\n",
      " 'validation': 'oss://pai-quickstart-cn-hangzhou.oss-cn-hangzhou-internal.aliyuncs.com/huggingface/datasets/llm_instruct/ch_poetry_test.json'}\n"
     ]
    }
    ],
    "source": [
    "from pprint import pprint\n",
    "\n",
    "# Get the default training inputs that ship with the model\n",
    "training_inputs = m.get_estimator_inputs()\n",
    "\n",
    "pprint(training_inputs)"
   ]
  },
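  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model provides several default inputs for the algorithm:\n",
    "\n",
    "- `model`: the pretrained model used by the training algorithm; in this example, the OSS path of the `ChatGLM2-6b` model.\n",
    "- `train`: the training dataset used by the algorithm\n",
    "- `validation`: the validation dataset used by the algorithm\n",
    "\n",
    "\n",
    "PAI hosts these datasets and models in a public-read OSS bucket. They can be mounted into the training job for the training code to use, and you can also download a dataset locally to inspect its format.\n"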
   ]
  },
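  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of pointing the algorithm at your own data, the default inputs can be copied and the `train` and `validation` entries replaced with your own OSS URIs. The bucket name and file paths below are hypothetical placeholders, and `default_inputs` is a stand-in for the dictionary returned by `m.get_estimator_inputs()`.\n",
    "\n",
    "```python\n",
    "# Stand-in for the dict returned by m.get_estimator_inputs(); the paths are placeholders.\n",
    "default_inputs = {\n",
    "    \"model\": \"oss://<pai-public-bucket>/path/to/chatglm2-6b/\",\n",
    "    \"train\": \"oss://<pai-public-bucket>/path/to/ch_poetry_train.json\",\n",
    "    \"validation\": \"oss://<pai-public-bucket>/path/to/ch_poetry_test.json\",\n",
    "}\n",
    "\n",
    "# Hypothetical OSS locations of your own dataset files.\n",
    "custom_inputs = dict(\n",
    "    default_inputs,\n",
    "    train=\"oss://<YOUR_BUCKET>/datasets/my_train.json\",\n",
    "    validation=\"oss://<YOUR_BUCKET>/datasets/my_validation.json\",\n",
    ")\n",
    "\n",
    "# The pretrained model entry is kept unchanged; only the datasets are swapped.\n",
    "print(custom_inputs)\n",
    "```\n",
    "\n",
    "Assuming the files follow the dataset format the algorithm expects, the resulting dictionary could then be passed as `est.fit(inputs=custom_inputs)`."
   ]
  },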
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.common.oss_utils import download\n",
    "\n",
    "from pai.session import get_default_session\n",
    "\n",
    "region_id = get_default_session().region_id\n",
    "\n",
    "\n",
    "# Download the training dataset to a local directory\n",
    "local_path = download(\n",
    "    f\"oss://pai-quickstart-{region_id}/huggingface/datasets/llm_instruct/ch_poetry_train.json\",\n",
    "    \"./train_data/\",\n",
    ")"
   ]
  },
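  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fine-tuning jobs for large language models accept datasets in `JSON` format. The file contains a list of `Dict` objects, each with two fields, `instruction` and `output`: `instruction` is the input text and `output` is the corresponding output text."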
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\n",
      "    {\n",
      "        \"instruction\": \"写一首以“戌辰年孟冬月诗四首选一”为题的诗：\",\n",
      "        \"output\": \"卜居白下傍烟霞，送别情怀乱似麻。何日游踪重北上，并门犹有旧时家。\"\n",
      "    },\n",
      "    {\n",
      "        \"instruction\": \"写一首以“花甲述怀”为题的诗：\",\n",
      "        \"output\": \"屈指悬弧日，匆匆六十年。诗书敦夙契，风月结前缘。心有千秋志，囊无一个钱。苍松撑傲骨，绝不受人怜。\"\n",
      "    },\n",
      "    {\n",
      "        \"instruction\": \"写一首以“寓居夜感”为题的诗：\",\n",
      "        \"output\": \"独坐晚凉侵，客窗秋意深。风沙疑化雾，夜气欲成霖。家务劳人倦，浓茶代酒斟。哀鸿鸣四野，一并助长吟。\"\n",
      "    },\n",
      "    {\n",
      "        \"instruction\": \"写一首以“戏咏灵猫 其一”为题的诗：\",\n",
      "        \"output\": \"灵猫雅号是金狮，水样柔情山样姿。早晚门前迎送我，撒娇投抱更依依。\"\n",
      "    },\n",
      "    {\n",
      "        \"instruction\": \"写一首以“戏咏灵猫 其二”为题的诗：\",\n",
      "        \"output\": \"金狮淘气喜翻书，对镜窥容恐不如。深愧家贫亏待了，眠无暖榻食无鱼。\"\n",
      "    },\n",
      "    {\n",
      "        \"instruction\": \"写一首以“次答友人思乡诗”为题的诗：\",\n",
      "        \"output\": \"阅尽沧桑万事空，何如归卧夕阳中。并州最是伤心地，四十馀年噩梦同。\"\n",
      "    },\n",
      "    {\n",
      "        \"instruction\": \"写一首以“南归杂感 其一”为题的诗：\",\n",
      "        \"output\": \"一到湘中耳目新，稻苗翻浪草铺茵。烟波无限江南趣，风柳非同塞北春。车辙移时山掠影，市声起处路扬尘。近乡莫怪行偏缓，屈指知交剩几人。\"\n",
      "    },\n",
      "    {\n"
     ]
    }
   ],
   "source": [
    "!head -n 30 train_data/ch_poetry_train.json"
   ]
  },
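  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the format concrete, the following sketch writes a tiny dataset in the same shape with the standard `json` module and checks that every record has the two expected fields. The file name `my_train.json` and the sample records are made up for illustration.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Illustrative records; replace them with your own instruction/output pairs.\n",
    "records = [\n",
    "    {\n",
    "        \"instruction\": \"Write a poem titled 'Autumn Night':\",\n",
    "        \"output\": \"<your reference poem>\",\n",
    "    },\n",
    "    {\n",
    "        \"instruction\": \"Write a poem titled 'Spring Rain':\",\n",
    "        \"output\": \"<your reference poem>\",\n",
    "    },\n",
    "]\n",
    "\n",
    "# ensure_ascii=False keeps non-ASCII text (such as Chinese) readable in the file.\n",
    "with open(\"my_train.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(records, f, ensure_ascii=False, indent=4)\n",
    "\n",
    "# Reload the file and verify that each record is a dict with both fields.\n",
    "with open(\"my_train.json\", encoding=\"utf-8\") as f:\n",
    "    loaded = json.load(f)\n",
    "\n",
    "for item in loaded:\n",
    "    assert isinstance(item, dict)\n",
    "    assert set(item) == {\"instruction\", \"output\"}\n",
    "\n",
    "print(len(loaded), \"records written\")\n",
    "```\n",
    "\n",
    "A file produced this way would still need to be uploaded to an OSS bucket before a training job can read it."
   ]
  },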
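  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we submit a fine-tuning job using the default dataset. By default, the SDK prints a link to the training job in the PAI console, where you can check the job's status and logs."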
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "est.fit(inputs=training_inputs)"
   ]
  },
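  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, `est.fit` prints the training job logs and waits until the job finishes (succeeded, failed, or terminated). After the job succeeds, call `est.model_data()` to get the path of the model output by the training job."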
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(est.model_data())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
